
Inference | Hybrid prefix caching. #3225

Open
lmcafee-nvidia wants to merge 13 commits into NVIDIA:main from lmcafee-nvidia:prefix-caching-mamba

Conversation


@lmcafee-nvidia lmcafee-nvidia commented Feb 3, 2026

What does this PR do ?

Add Mamba state prefix caching for hybrid Transformer-Mamba models, enabling KV cache prefix sharing to also share corresponding Mamba conv/SSM states.

Key Features

  • Mamba state caching: When KV cache blocks are reused due to shared prefixes, the corresponding Mamba states can now also be cached and restored, avoiding redundant recomputation
  • Memory-budgeted cache: New --inference-dynamic-batching-prefix-caching-mamba-gb argument controls the memory budget for cached Mamba states
  • LRU eviction: Mamba cache uses LRU eviction based on KV block timestamps when the cache is full
  • Automatic invalidation: When KV blocks are evicted, their associated Mamba states are automatically invalidated
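
The bookkeeping implied by these features can be sketched as below. This is an illustrative model, not the Megatron-Core implementation: the class and method names are hypothetical, and the actual cache lives in `dynamic_context.py` with GPU tensor slots rather than a plain dict.

```python
# Hypothetical sketch of the cache bookkeeping described above: a fixed pool
# of Mamba state slots keyed by KV block id, LRU eviction by the KV block's
# last-use timestamp, and invalidation when a KV block is evicted.
class MambaStateCache:
    def __init__(self, num_slots):
        self.num_slots = num_slots
        self.block_to_slot = {}  # KV block id -> slot index in the state pool
        self.free_slots = list(range(num_slots))
        self.timestamps = {}     # KV block id -> last-use timestamp

    def store(self, block_id, now):
        """Reserve a slot for this KV block's Mamba state, evicting LRU if full."""
        if block_id in self.block_to_slot:
            self.timestamps[block_id] = now
            return self.block_to_slot[block_id]
        if not self.free_slots:
            # LRU eviction: drop the least recently used block's slot.
            victim = min(self.timestamps, key=self.timestamps.get)
            self.invalidate(victim)
        slot = self.free_slots.pop()
        self.block_to_slot[block_id] = slot
        self.timestamps[block_id] = now
        return slot

    def lookup(self, block_id, now):
        """Return the cached slot for a prefix-matched KV block, or None."""
        slot = self.block_to_slot.get(block_id)
        if slot is not None:
            self.timestamps[block_id] = now
        return slot

    def invalidate(self, block_id):
        # Called when the KV block is evicted from the KV cache.
        slot = self.block_to_slot.pop(block_id, None)
        if slot is not None:
            self.timestamps.pop(block_id, None)
            self.free_slots.append(slot)
```

In the real implementation the slot count would be derived from the `--inference-dynamic-batching-prefix-caching-mamba-gb` budget divided by the per-block conv/SSM state size.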

Changes

Core Implementation (megatron/core/inference/contexts/dynamic_context.py, megatron/core/inference/engines/dynamic_engine.py):

  • Fixed-size tensor pool for cached Mamba conv/SSM states
  • Block-to-slot mapping for associating KV blocks with Mamba cache slots
  • Store Mamba state at block boundaries after prefill completes
  • Restore cached Mamba state when prefix matches during request scheduling

Block Allocator (megatron/core/inference/contexts/dynamic_block_allocator.py):

  • Invalidate Mamba state when KV blocks are evicted

Arguments (megatron/training/arguments.py):

  • Add --inference-dynamic-batching-prefix-caching-mamba-gb parameter
  • Auto-enable chunked prefill when Mamba prefix caching is used
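
A minimal sketch of the auto-enable behavior, assuming standard argparse handling; the flag name comes from this PR, but the helper, the `--enable-chunked-prefill` flag name, and the parsing logic here are illustrative, not the actual `megatron/training/arguments.py` code:

```python
import argparse

def parse_inference_args(argv):
    # Hypothetical parser: a nonzero Mamba prefix-cache budget implies
    # chunked prefill, so it is switched on automatically.
    parser = argparse.ArgumentParser()
    parser.add_argument('--inference-dynamic-batching-prefix-caching-mamba-gb',
                        type=float, default=0.0,
                        help='Memory budget (GB) for cached Mamba states.')
    parser.add_argument('--enable-chunked-prefill', action='store_true')
    args = parser.parse_args(argv)
    if args.inference_dynamic_batching_prefix_caching_mamba_gb > 0:
        args.enable_chunked_prefill = True
    return args
```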

Test plan

  • Unit tests for cache allocation/deallocation
  • Unit tests for LRU eviction (20 tests in tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py)
  • Tests for memory budget edge cases
  • Tests for state store/restore integration
  • End-to-end inference test with hybrid model

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review may be declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@lmcafee-nvidia lmcafee-nvidia added this to the Core 0.15 milestone Feb 3, 2026
@lmcafee-nvidia lmcafee-nvidia self-assigned this Feb 3, 2026
@lmcafee-nvidia lmcafee-nvidia requested a review from a team as a code owner February 3, 2026 14:49
@lmcafee-nvidia lmcafee-nvidia added the Expert Review label Feb 3, 2026
@lmcafee-nvidia lmcafee-nvidia requested review from a team as code owners February 3, 2026 14:49
@copy-pr-bot

copy-pr-bot bot commented Feb 3, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 3, 2026 14:50
Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
@lmcafee-nvidia lmcafee-nvidia requested a review from a team as a code owner February 18, 2026 16:35
@lmcafee-nvidia lmcafee-nvidia force-pushed the prefix-caching-mamba branch 4 times, most recently from 27831b6 to e4a98ee Compare March 3, 2026 20:15
lmcafee-nvidia and others added 4 commits March 5, 2026 05:38
Bug fixes:
1. conv_state save-before-read: extract initial conv states BEFORE
   causal_conv1d_varlen_states + tensor_masked_update overwrites the
   conv_state buffer. Previously, initial_conv_states was read AFTER
   the buffer was updated, so restored requests would see their own
   newly-computed final states instead of the pre-existing initial
   states, corrupting the convolution output.

2. cu_chunk_seqlens OOB: the SSM Triton kernels allocate per-chunk
   output arrays of size chunk_size (128). Passing cu_seqlens directly
   as cu_chunk_seqlens caused out-of-bounds memory access when any
   sequence exceeded chunk_size tokens. Fix: subdivide each sequence
   into chunks of at most self.chunk_size, producing correct
   cu_chunk_seqlens boundaries.

3. zxBCdt padding mismatch: after conv1d, the per-request loop rebuilt
   xBC with only real tokens while dt and z retained padded token count.
   This caused a shape assertion failure in the SSM kernel. Fix: strip
   padded tokens from zxBCdt before _ssm_prefill, then pad the output
   back to the original padded size for downstream residual add.

4. Per-request conv1d with initial_states: causal_conv1d_fn cannot
   accept both seq_idx and initial_states simultaneously. The old code
   passed seq_idx to handle multiple sequences but this zeroes state at
   sequence boundaries instead of using the cached initial states. Fix:
   loop over requests, calling causal_conv1d_fn per-request with
   initial_states and channels-last layout.

Improvements:
- Unify all Mamba prefill (including chunked) through single varlen SSM
  kernel call, removing separate chunked-prefill routing and the
  _batch_indices_chunked_prefill / _device_chunked_prefill metadata
- Simplify _dynamic_inference to flat decode + prefill structure
- Add _dynamic_inference_prefill helper that strips CUDA-graph padding
  from metadata and data tensors before calling _ssm_prefill
- Remove deprecated constructor parameters (use_mem_eff_path, d_state,
  headdim, ngroups) and their warnings
- Add assertion format string in ssd_combined.py for easier debugging

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
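
The cu_chunk_seqlens fix (item 2 above) can be sketched as follows; the function name is illustrative, and the real code operates on device tensors rather than Python lists. Given the standard varlen prefix-sum `cu_seqlens`, each sequence is subdivided into chunks of at most `chunk_size` so that no per-chunk kernel buffer is overrun:

```python
def build_cu_chunk_seqlens(cu_seqlens, chunk_size):
    # Subdivide each [start, end) sequence into chunks of at most chunk_size
    # tokens, keeping every original sequence boundary as a chunk boundary.
    cu_chunk = [0]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        pos = start
        while pos < end:
            pos = min(pos + chunk_size, end)
            cu_chunk.append(pos)
    return cu_chunk
```

For example, with `chunk_size=128`, sequences of 300 and 128 tokens yield boundaries `[0, 128, 256, 300, 428]`: no chunk exceeds 128 tokens, so the per-chunk output arrays of size `chunk_size` are never indexed out of bounds.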
…raction

With chunk-aligned sequences (one sequence per chunk boundary), the final
SSM state for each sequence is simply states[last_chunk_indices], making
the separate chunk_state_varlen Triton kernel unnecessary. Construct
last_chunk_indices in mamba_mixer.py alongside cu_chunk_seqlens and
remove the cu_seqlens parameter from the varlen API since it was only
needed by chunk_state_varlen.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
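
A sketch of the `last_chunk_indices` construction described above, assuming sequence boundaries are chunk-aligned so every `cu_seqlens` entry appears in `cu_chunk_seqlens` (the helper name and list-based representation are illustrative). The final SSM state of sequence `i` is then simply `states[last_chunk_indices[i]]`:

```python
def build_last_chunk_indices(cu_seqlens, cu_chunk_seqlens):
    # Chunk j spans tokens cu_chunk_seqlens[j] .. cu_chunk_seqlens[j+1].
    # Map each chunk's end offset to its chunk index, then look up each
    # sequence's end offset to find its final chunk.
    end_to_chunk = {end: j for j, end in enumerate(cu_chunk_seqlens[1:])}
    return [end_to_chunk[end] for end in cu_seqlens[1:]]
```

This is why the separate `chunk_state_varlen` kernel becomes unnecessary: the per-sequence final state is a pure gather over the per-chunk states.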
For prefix caching of Mamba layers, extract SSM and conv states at
block-aligned chunk boundaries during varlen prefill. Since
block_size_tokens % chunk_size == 0, every block boundary falls on a
chunk boundary, making intermediate SSM state extraction pure indexing
with no extra computation. Conv states are sliced from the pre-conv
input tensor.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
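
The block-aligned extraction above can be sketched like this (illustrative names and plain lists, not the actual tensor code). Because `block_size_tokens % chunk_size == 0`, every block boundary inside a sequence lands on a chunk boundary, so extracting the intermediate SSM state is pure indexing into the per-chunk states:

```python
def extract_block_states(chunk_states, cu_chunk_seqlens, seq_start, seq_len,
                         block_size_tokens, chunk_size):
    # Every block boundary must coincide with a chunk boundary.
    assert block_size_tokens % chunk_size == 0
    end_to_chunk = {end: j for j, end in enumerate(cu_chunk_seqlens[1:])}
    # Gather the SSM state at each full-block boundary of this sequence;
    # no extra computation is needed beyond the lookup.
    return [chunk_states[end_to_chunk[seq_start + b]]
            for b in range(block_size_tokens, seq_len + 1, block_size_tokens)]
```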
Instead of breaking prefill into chunks at Mamba state boundaries,
process full prefill in one kernel call and extract SSM/conv states
at specified token offsets. This eliminates the dependency on chunked
prefill scheduling and simplifies the engine.

Key changes:
- Two-map hash design: kv_hash_to_block_id + mamba_hash_to_block_id
- Mamba cache infrastructure: GPU memory pool for SSM/conv states
- Coupled prefix matching: skip tokens limited by Mamba match count
- Intermediate offset computation at KV divergence and last-aligned
- Engine passes mamba match count, commits states after forward pass
- KV eviction automatically invalidates Mamba state
- 18 new tests covering all Mamba caching paths

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
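
The coupled prefix matching described in this commit can be sketched as below. The two-map design is from the commit message; the function and map names are illustrative. KV blocks can match arbitrarily far down the hash chain, but because Mamba state is strictly sequential, tokens may only be skipped up to the last contiguous block boundary whose Mamba state is also cached:

```python
def match_prefix(block_hashes, kv_hash_to_block_id, mamba_hash_to_block_id):
    # Longest contiguous KV-block prefix match.
    kv_matches = 0
    for h in block_hashes:
        if h not in kv_hash_to_block_id:
            break
        kv_matches += 1
    # Mamba match must be a contiguous prefix of the KV match, since the
    # SSM/conv state at a boundary is only valid if all earlier state is.
    mamba_matches = 0
    for h in block_hashes[:kv_matches]:
        if h not in mamba_hash_to_block_id:
            break
        mamba_matches += 1
    # The scheduler may skip at most mamba_matches blocks' worth of tokens.
    return kv_matches, mamba_matches
```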

Labels

complexity: high, Expert Review

3 participants